{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Part 6 - Saving and Loading Models.ipynb",
      "version": "0.3.2",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    }
  },
  "cells": [
    {
      "metadata": {
        "id": "R4xTHH-ilnmf",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Saving and Loading Models\n",
        "\n",
        "In this notebook, I'll show you how to save and load models with PyTorch. This is important because you'll often want to load previously trained models to use in making predictions or to continue training on new data."
      ]
    },
    {
      "metadata": {
        "id": "EzGz9PfjnF6v",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 341
        },
        "outputId": "7bc67aee-80ee-4e9a-8b0a-76de24b13a14"
      },
      "cell_type": "code",
      "source": [
        "!pip install torch torchvision"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Collecting torch\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/49/0e/e382bcf1a6ae8225f50b99cc26effa2d4cc6d66975ccf3fa9590efcbedce/torch-0.4.1-cp36-cp36m-manylinux1_x86_64.whl (519.5MB)\n",
            "\u001b[K    100% |████████████████████████████████| 519.5MB 27kB/s \n",
            "tcmalloc: large alloc 1073750016 bytes == 0x59002000 @  0x7f2c42c472a4 0x591a07 0x5b5d56 0x502e9a 0x506859 0x502209 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x507641 0x502209 0x502f3d 0x506859 0x504c28 0x502540 0x502f3d 0x507641 0x504c28 0x502540 0x502f3d 0x507641\n",
            "\u001b[?25hCollecting torchvision\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/ca/0d/f00b2885711e08bd71242ebe7b96561e6f6d01fdb4b9dcf4d37e2e13c5e1/torchvision-0.2.1-py2.py3-none-any.whl (54kB)\n",
            "\u001b[K    100% |████████████████████████████████| 61kB 20.1MB/s \n",
            "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.14.6)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from torchvision) (1.11.0)\n",
            "Collecting pillow>=4.1.1 (from torchvision)\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/62/94/5430ebaa83f91cc7a9f687ff5238e26164a779cca2ef9903232268b0a318/Pillow-5.3.0-cp36-cp36m-manylinux1_x86_64.whl (2.0MB)\n",
            "\u001b[K    100% |████████████████████████████████| 2.0MB 4.9MB/s \n",
            "\u001b[?25hInstalling collected packages: torch, pillow, torchvision\n",
            "  Found existing installation: Pillow 4.0.0\n",
            "    Uninstalling Pillow-4.0.0:\n",
            "      Successfully uninstalled Pillow-4.0.0\n",
            "Successfully installed pillow-5.3.0 torch-0.4.1 torchvision-0.2.1\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "IfPf7TM_rIJ4",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Helper Functions\n",
        "class Network(nn.Module):\n",
        "    def __init__(self, input_size, output_size, hidden_layers, drop_p=0.5):\n",
        "        ''' Builds a feedforward network with arbitrary hidden layers.\n",
        "        \n",
        "            Arguments\n",
        "            ---------\n",
        "            input_size: integer, size of the input layer\n",
        "            output_size: integer, size of the output layer\n",
        "            hidden_layers: list of integers, the sizes of the hidden layers\n",
        "        \n",
        "        '''\n",
        "        super().__init__()\n",
        "        # Input to a hidden layer\n",
        "        self.hidden_layers = nn.ModuleList([nn.Linear(input_size, hidden_layers[0])])\n",
        "        \n",
        "        # Add a variable number of more hidden layers\n",
        "        layer_sizes = zip(hidden_layers[:-1], hidden_layers[1:])\n",
        "        self.hidden_layers.extend([nn.Linear(h1, h2) for h1, h2 in layer_sizes])\n",
        "        \n",
        "        self.output = nn.Linear(hidden_layers[-1], output_size)\n",
        "        \n",
        "        self.dropout = nn.Dropout(p=drop_p)\n",
        "        \n",
        "    def forward(self, x):\n",
        "        ''' Forward pass through the network, returns the output logits '''\n",
        "        \n",
        "        for each in self.hidden_layers:\n",
        "            x = F.relu(each(x))\n",
        "            x = self.dropout(x)\n",
        "        x = self.output(x)\n",
        "        \n",
        "        return F.log_softmax(x, dim=1)\n",
        "\n",
        "\n",
        "def validation(model, testloader, criterion):\n",
        "    accuracy = 0\n",
        "    test_loss = 0\n",
        "    for images, labels in testloader:\n",
        "\n",
        "        images = images.resize_(images.size()[0], 784)\n",
        "\n",
        "        output = model.forward(images)\n",
        "        test_loss += criterion(output, labels).item()\n",
        "\n",
        "        ## Calculating the accuracy \n",
        "        # Model's output is log-softmax, take exponential to get the probabilities\n",
        "        ps = torch.exp(output)\n",
        "        # Class with highest probability is our predicted class, compare with true label\n",
        "        equality = (labels.data == ps.max(1)[1])\n",
        "        # Accuracy is number of correct predictions divided by all predictions, just take the mean\n",
        "        accuracy += equality.type_as(torch.FloatTensor()).mean()\n",
        "\n",
        "    return test_loss, accuracy\n",
        "\n",
        "\n",
        "def train(model, trainloader, testloader, criterion, optimizer, epochs=5, print_every=40):\n",
        "    \n",
        "    steps = 0\n",
        "    running_loss = 0\n",
        "    for e in range(epochs):\n",
        "        # Model in training mode, dropout is on\n",
        "        model.train()\n",
        "        for images, labels in trainloader:\n",
        "            steps += 1\n",
        "            \n",
        "            # Flatten images into a 784 long vector\n",
        "            images.resize_(images.size()[0], 784)\n",
        "            \n",
        "            optimizer.zero_grad()\n",
        "            \n",
        "            output = model.forward(images)\n",
        "            loss = criterion(output, labels)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            \n",
        "            running_loss += loss.item()\n",
        "\n",
        "            if steps % print_every == 0:\n",
        "                # Model in inference mode, dropout is off\n",
        "                model.eval()\n",
        "                \n",
        "                # Turn off gradients for validation, will speed up inference\n",
        "                with torch.no_grad():\n",
        "                    test_loss, accuracy = validation(model, testloader, criterion)\n",
        "                \n",
        "                print(\"Epoch: {}/{}.. \".format(e+1, epochs),\n",
        "                      \"Training Loss: {:.3f}.. \".format(running_loss/print_every),\n",
        "                      \"Test Loss: {:.3f}.. \".format(test_loss/len(testloader)),\n",
        "                      \"Test Accuracy: {:.3f}\".format(accuracy/len(testloader)))\n",
        "                \n",
        "                running_loss = 0\n",
        "                \n",
        "                # Make sure dropout and grads are on for training\n",
        "                model.train()\n",
        "    \n",
        "    \n",
        "def imshow(image, ax=None, title=None, normalize=True):\n",
        "    \"\"\"Imshow for Tensor.\"\"\"\n",
        "    if ax is None:\n",
        "        fig, ax = plt.subplots()\n",
        "    image = image.numpy().transpose((1, 2, 0))\n",
        "\n",
        "    if normalize:\n",
        "        mean = np.array([0.485, 0.456, 0.406])\n",
        "        std = np.array([0.229, 0.224, 0.225])\n",
        "        image = std * image + mean\n",
        "        image = np.clip(image, 0, 1)\n",
        "\n",
        "    ax.imshow(image)\n",
        "    ax.spines['top'].set_visible(False)\n",
        "    ax.spines['right'].set_visible(False)\n",
        "    ax.spines['left'].set_visible(False)\n",
        "    ax.spines['bottom'].set_visible(False)\n",
        "    ax.tick_params(axis='both', length=0)\n",
        "    ax.set_xticklabels('')\n",
        "    ax.set_yticklabels('')\n",
        "\n",
        "    return ax    "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "A1mHxm_9lnmr",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "%matplotlib inline\n",
        "%config InlineBackend.figure_format = 'retina'\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import torch\n",
        "from torch import nn\n",
        "from torch import optim\n",
        "import torch.nn.functional as F\n",
        "from torchvision import datasets, transforms\n",
        "import numpy as np\n",
        "\n",
        "#import helper\n",
        "#import fc_model"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "pEfnyK2Dlnm-",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 124
        },
        "outputId": "47d4a134-d8f4-4498-cd3e-924fd11453c8"
      },
      "cell_type": "code",
      "source": [
        "# Define a transform to normalize the data\n",
        "transform = transforms.Compose([transforms.ToTensor(),\n",
        "                                transforms.Normalize((0.5,), (0.5,))])\n",
        "# Download and load the training data\n",
        "trainset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=True, transform=transform)\n",
        "trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)\n",
        "\n",
        "# Download and load the test data\n",
        "testset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=False, transform=transform)\n",
        "testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz\n",
            "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz\n",
            "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz\n",
            "Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz\n",
            "Processing...\n",
            "Done!\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "OrO4-lD-lnnM",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "Here we can see one of the images."
      ]
    },
    {
      "metadata": {
        "id": "qB44NaCHlnnQ",
        "colab_type": "code",
        "outputId": "c7fc3403-30cc-4d1d-e2b2-ba97dd6280ac",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 266
        }
      },
      "cell_type": "code",
      "source": [
        "image, label = next(iter(trainloader))\n",
        "#helper.imshow(image[0,:]);\n",
        "imshow(image[0,:])"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.axes._subplots.AxesSubplot at 0x7f3304d3d978>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 12
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAc8AAAHPCAYAAAA1eFErAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAWJQAAFiUBSVIk8AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAE+NJREFUeJzt3XuQnWV9wPHn7Ek2l91cSITQclOS\nGRJFInKJhFQwtEqptTOItdwEsWrbGWmHaemUjv5RZ+wMdZgWO50qiNYLVFsv4NAqtnFsEyBBcBKQ\nQKUBkmAgkJC9TSC7e07/aYZOs+855/c82Zxd8vn8w0ze57fPcnaz37wz+56n1mw2EwDQuZ5ufwIA\nMN2IJwAEiScABIknAASJJwAEiScABIknAASJJwAEiScABIknAASJJwAEiScABIknAATNyB1cu2Zl\n5XEs69ZvPrgm98MflbxucV6zPF63PF63uKn+mq1bv7mWM+fOEwCCxBMAgsQTAILEEwCCxBMAgmrN\nZuUvzbaTPQgAU4TftgWAI6HkOc/Ka1P9uZ6pyusW5zXL43XL43WLm+qv2cHPL8qdJwAEiScABIkn\nAASJJwAEiScABIknAASJJwAEZT/nCTBZarWsN305LAreda3Yuy68MHt227ZtRXs/u3170fzRxp0n\nAASJJwAEiScABIknAASJJwAEiScABIknAASJJwAEiScABIknAASJJwAEiScABIknAASJJwAEOZIM\nmFBPwbFgjcJjvbp5LFg3rVi+Int2Xn9/0d7dPJKs5Ai6bn2vuPMEgCDxBIAg8QSAIPEEgCDxBIAg\n8QSAIPEEgCDxBIAg8QSAIPEEgCDxBIAg8QSAIPEEgCDxBIAg8QSAIOd5wiTq6cn/92m9Xi/ae3R0\ntGi+9EzObjl+yZKi+VNPPbXtmtXnnTfhn594wglFezebjezZ+fMXFO09f9687NnBoaGivfNP80yp\nW9+l7jwBIEg8ASBIPAEgSDwBIEg8ASBIPAEgSDwBIEg8ASBIPAEgSDwBIEg8ASBIPAEgSDwBIEg8\nASDIkWTQQr3gSLGUUhpv5B8x1SiYPRwWL1qUPbt8+fKivRcdc0z27OJFi4v2njt3bts1bztj5YR/\nfmD0QNHeO3c+lz178sknFe19/urzs2f/9QffL9p7Oh5/584TAILEEwCCxBMAgsQTAILEEwCCxBMA\ngsQTAILEEwCCxBMAgsQTAILEEwCCxBMAgsQTAILEEwCCxBMAgpznCS2UnMdZatnSpUXzy09rf6bm\ney/5jcprCxcuyN673lPPnk0ppd7e3uzZAwfKztTc9vS27DX/tm5d0d4lPnLth4vmT33Tm7Jn5/X3\nF60ZGh7O3rtb3HkCQJB4AkCQeAJAkHgCQJB4AkBQrdls5s5mDwLAFFHLGXLnCQBB2c95rl2zsvLa\nuvWb267hUF63uNfzazaZz3necNNnUkop3fKZmyrXHK3Peb6w+4XKazd+6q9SSind/Bd/MuH16fyc\n54wZ+Y/9f+VrX628dvf3N6SUUvqti8+vXNPN5zwP/gyJcucJAEHiCQBB4gkAQeIJAEHiCQBB4gkA\nQY4kgxYWLMh/XCOllNZe+K7s2WOPfUPR3rVa+2e/TzjhlyuvjY6OZu89MjKSPZtSSo8/sTV79uGH\nHy7au9HijWMOPqry/fvuK9pjMsyfP69ovuRxkQsvuKBozffuvTd7725x5wkAQeIJAEHiCQBB4gkA\nQeIJAEHiCQBB4gkAQeIJAEHiCQBB4gkAQeIJAEHiCQBB4gkAQeIJAEHiCQBBzvOkI52cDdlKs8UZ\nie3MmTOnaM3F73539t4nnnhi9mxKZa/bvn37ivY+cOBA2zUDAwOV137yyCPZez/11FPZs+R5ufD7\npV6vZ8+edNLJRWvevGJF9t6P
b80/+7WEO08ACBJPAAgSTwAIEk8ACBJPAAgSTwAIEk8ACBJPAAgS\nTwAIEk8ACBJPAAgSTwAIEk8ACBJPAAhyJBlT3v79+4vW9PX1Ze/90ksvZc+mlNLPfvZ49uzTzzxd\ntPfQ8HDltY9d/+cppZT+8ZvfLNrj9Wj27NnZa045uf3RXK0sXrw4e7bZyD/2L6WUGrVG9uzY6Ctt\n17z6SvWaSy7+9ey9HUkGANOEeAJAkHgCQJB4AkCQeAJAkHgCQJB4AkCQeAJAkHgCQJB4AkCQeAJA\nkHgCQJB4AkCQeAJAkHgCQJDzPOlIs1l2VmA3fe3OO7Nnjz322KK9X3zxxaL5blq0aFH27LKlS4v2\nXrhwYfbsLy05vmjvTs5/ve6aayf889K/J/39+WfPvrxvX9HeKdUmd7TFmn0Fn/vJhWeo5nLnCQBB\n4gkAQeIJAEHiCQBB4gkAQeIJAEHiCQBB4gkAQeIJAEHiCQBB4gkAQeIJAEHiCQBB4gkAQY4k44jo\nqeUfd9To4nFopUeKLTrmmOzZi9auLdq7r6+/7ZprP3RN5bVawddsZGQkezallAYHB7Nntz75RNHe\nO3bsqLx22dW/n1JK6Tt3f3fC68OF/99DQ0PZs7960UVFey8/7bTs2U4+77GxscprPT317L3fsGhx\n9mwJd54AECSeABAkngAQJJ4AECSeABAkngAQVGvmPwbQvecHAODwyHomy50nAARlv0nC2jUrK6+t\nW7+57RoO9Xp+3SbrTRKm+ms2Vd8k4UMfuyGllNJXvnBL5Zrp+iYJe1/eW7R3qzdJ+Po//yCllNKV\nl71nwuveJOFQ13z8j1NKKf3D5z9buabkTRI2b96cPZtSSp/92y9nzbnzBIAg8QSAIPEEgCDxBIAg\n8QSAIPEEgCDnefK6d96qVdmzy5evKNp77tw52bN79uwp2nvr1q1Fa3629fHsvYeHh7Nnp4Ndzz/f\n7U/hEPWewnuhgre9GR8fL1rT38HZs1X2v7I/e7aEO08ACBJPAAgSTwAIEk8ACBJPAAgSTwAIEk8A\nCBJPAAgSTwAIEk8ACBJPAAgSTwAIEk8ACBJPAAhyJBlHRK1Wyx9uFpyVlFJacvzx2bODg4NFe3/l\na1/Nnu3kmKdSGx/aNOl7TDedfK9WrWkWfq+WmDu3r2h+dHQ0e7bkNUsppZ56/n3cc889lz1bwp0n\nAASJJwAEiScABIknAASJJwAEiScABIknAASJJwAEiScABIknAASJJwAEiScABIknAASJJwAEiScA\nBDnPkyOi0cVzDr97991d25vpp5MzOavW9M7sLdr7fb/53uzZRYsWFe09PDKcPdvJeZ6tXteRkZHs\nvQeHhrJnS7jzBIAg8QSAIPEEgCDxBIAg8QSAIPEEgCDxBIAg8QSAIPEEgCDxBIAg8QSAIPEEgCDx\nBIAg8QSAIEeScUR0cswTTAWrzj03e82qc9rPttLJ0V5V9g3sK9p7Rj0/B41mo+2aVv9vAwOD2Xt3\niztPAAgSTwAIEk8ACBJPAAgSTwAIEk8ACBJPAAgSTwAIEk8ACBJPAAgSTwAIEk8ACBJPAAgSTwAI\nEk8ACHKeJ1PekiVLitYsO3Vp9t79/X3ZsymlNGNG/l+x3t5ZRXvPnNl+79++7LIW873ZexccS5lS\nSmnv3r3ZszN78z/vlFJ60ymntF1z3qpVE/75wz/9adHeG+6/P3u20Wh/pmYrJWeRnnP22W3XLJi/\noPLaM888m713t7jzBIAg8QSAIPEEgCDxBIAg8QSAoFqz2cydzR4EgCki63fD3XkCQFD2Q2hr16ys\nvLZu/ea2aziU121irZ7hvOtb96WUUrr8/e+uXOM5z0P93h99MqWU0t//9adbzHvO8/973wd/N6WU\n0j3fuH3C657zPNSlV348pZTSt7/++co1m7dsyd77Pzesz55N6bWfu1HuPAEgSDwBIEg8ASBIPA
Eg\nSDwBIEg8ASBIPAEgyHmeR5F6T/6/lcYLnyF76+mnZ8++59eqn+E86Oorrqy8Njw8nL133nuPvOb5\nF17Int29e3fR3uPj423XbN++o/La2PhY9t6zZpU9o7ryjDOyZ8fG8j/vlFL63r33Vl47+Jxn1Zr/\n3rataO9u2vjQpuzZ81evbrumt8Xztzt27szeu1vceQJAkHgCQJB4AkCQeAJAkHgCQJB4AkCQeAJA\nkHgCQJB4AkCQeAJAkHgCQJB4AkCQeAJAkHgCQJAjyY4ipceKlXj0sceyZ5999tnKa5e8/5qUUkpf\nuP22yjWDQ0PZe7/ePbhp46R83At+5Z1F8+Nj7Y9Tq/Ktb3+naO89e/e0XTNZR4/1FBwb2Cj8+93f\n15c922w2i9YMDg5k790t7jwBIEg8ASBIPAEgSDwBIEg8ASBIPAEgSDwBIEg8ASBIPAEgSDwBIEg8\nASBIPAEgSDwBIEg8ASBIPAEgyHmeR1gn5/VVrSk9r2/F8hXZs81m2d7z5s3Lnp07Z07bNWe+7W2V\n1+Z0MF/lpT3tz3Zs5dnt27NnX3zxxaK9S71j1ars2befeWbR3t/4p29mz3ZyHudU1cm5mJNlZm9v\n9ux4o/35q63WDA+PZO/dLe48ASBIPAEgSDwBIEg8ASBIPAEgSDwBIEg8ASBIPAEgSDwBIEg8ASBI\nPAEgSDwBIEg8ASBIPAEgyJFkR1gnx4qVHj1W5bnndmbPXnXllUV7lxwLNjQ01HbNaaedVnmtk2Pg\nqixbuix7NqWUzjnr7OzZwaHBor3/6+c/b7vmnLOrP79Wx7y1s+mhh7JnU0rpF7t2Fc2XqNVqh2VN\njpLv1fHx9seCtXLMwoX5e491cCRZizUHRg9k790t7jwBIEg8ASBIPAEgSDwBIEg8ASBIPAEgSDwB\nIEg8ASBIPAEgSDwBIEg8ASBIPAEgSDwBIEg8ASBIPAEgyHmeR9jsWbOy17zy6qtFew92cC5mlfUb\n7i/ae6jgbMrjjjuu8trv/O9/N2/ZUrlmwYIF2XvPnzcvezallBYuyD8jsb+/v2jvlW89o2jN4GD+\n98uGB8q+X7qp2Wxmryk953OyzgntxOzZ+WfuHhgdPSxrphN3ngAQJJ4AECSeABAkngAQJJ4AEFTr\n5DfLKmQPAsAUkfUrzu48ASAo+znPtWtWVl5bt35z2zVHq1bPef7Lv29KKaV0yUXnTni99DnPEmd0\n8MxgK5P1nOdf3nJbSimlP7vho5VrputznvUZ9aK9x8fGK69d8ZE/TCmldOcX/6Zyzf5XXsne++t3\n3Zk9O5W1+9lW+pxmvZ7/NR8bGyva+80r3pw9e/7q1ZXXLv/wJ1JKKd31pc9Vrrnti7dn713q4Nc0\nyp0nAASJJwAEiScABIknAASJJwAEiScABDmS7Ai7+qqrs9c8uPHBor0ffeyx7Nktj1Yf+TXZnn7m\nmbZrNm7aNPmfyBH2jlWriuZXnTvxI0//V19/X+W1DQ88ULR/iZJHPgre+KVY6d7d/NxnzerNnh0d\nPXBY1kwn7jwBIEg8ASBIPAEgSDwBIEg8ASBIPAEgSDwBIEg8ASBIPAEgSDwBIEg8ASBIPAEgSDwB\nIEg8ASBIPAEgqGvnefYUnNdXrGDvRqNRtPX2Hduz15x/3uqivUvO8yzV05P/77TS13y6euMpbyya\n3717d9GarU9sLdq/RDfPtTxazZw5M3v2wKvtz+rsZM104s4TAILEEwCCxBMAgsQTAILEEwCCxBMA\ngsQTAILEEwCCxBMAgsQTAILEEwCCxBMAgsQTAILEEwCCunYkWaObRw51ce91P/pR5bU//VTrNZdd\nemnR3iuWL8+e3frEE0V7H63Hil14wQXZsyedeELR3nd8+cuV1z76iZtSSind98MfFu1RpfTIwa7+\nfDhK9c7szZ4dHRs9LGumE3eeABAkngAQJJ4AECSeABAkng
AQJJ4AECSeABAkngAQJJ4AECSeABAk\nngAQJJ4AECSeABAkngAQJJ4AENS18zxnzMjful6vF+3d39efPdvXN7do78WLF7ddc/pb3jLhn/f0\nlP1b5+yzzsqe/cWuXUV7DwwMFM13yzvXrCmaP6fgNd/40ENFe+/Zu/ewrMnhPM48tcJzUEvUZ+T/\nXB0fb39ebydrphN3ngAQJJ4AECSeABAkngAQJJ4AECSeABAkngAQJJ4AECSeABAkngAQJJ4AECSe\nABAkngAQJJ4AEJR9LtjZb29/1FKrNcuWLc3dOvX29mbPplR2HFrJbEopNRvtj2qqet3GG+NFe9fr\n+Z/7Bz/wgaK9h4aGsmc7+byvuuLKymtz587J3rv06/3AgxuzZ9ffv6Fob4jom5t/3GKjg59NnayZ\nTtx5AkCQeAJAkHgCQJB4AkCQeAJAkHgCQFCt2Wz/6ESF7EEAmCJqOUPuPAEgKPsJ8Buvv67y2s23\n3tF2jTdJONTl112fUkrprjtunfB6N98koZb1b7PXTNabJFz3BzemlFK64+9urlzTzTdJ2LLl0ezZ\nyXyThHXrN6eUUlq7ZuWk7fF6NNmvW8n329jYWNHel1x8cfZsq5/J19/46ZRSSrfe/MnKNd+9557s\nvUsd/JpGufMEgCDxBIAg8QSAIPEEgCDxBIAg8QSAoOzfi/7JIw8XrelkfrLUCp67KHhTibZ7H3xU\n5fYv3THh9dmzZxftPXvWrPzZOfmPe6SU0oL587NnR0ZGKq8dfFTlx//x48o1AwMD2XsPFjxi022d\nfJ+3WlP6vU5cY7x7Z14+8eST2bM99XrbNY89/nj2x5+K3HkCQJB4AkCQeAJAkHgCQJB4AkCQeAJA\nkHgCQJB4AkCQeAJAkHgCQJB4AkCQeAJAkHgCQJB4AkBQ9pFk01k3j1rqZO+qNfv37y/au2h+376i\nvXft2lU0386OnTsn9eNPRyXfa3RHo4tfj21PPz2pH/+pp56a1I9/pLnzBIAg8QSAIPEEgCDxBIAg\n8QSAIPEEgCDxBIAg8QSAIPEEgCDxBIAg8QSAIPEEgCDxBIAg8QSAIPEEgKCa8/wAIMadJwAEiScA\nBIknAASJJwAEiScABIknAASJJwAEiScABIknAASJJwAEiScABIknAASJJwAE/Q/Rr+29CgWO0AAA\nAABJRU5ErkJggg==\n",
            "text/plain": [
              "<matplotlib.figure.Figure at 0x7f33053bc208>"
            ]
          },
          "metadata": {
            "tags": [],
            "image/png": {
              "width": 231,
              "height": 231
            }
          }
        }
      ]
    },
    {
      "metadata": {
        "id": "OWeXOTe4lnnh",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Train a network\n",
        "\n",
        "To make things more concise here, I moved the model architecture and training code from the last part to a file called `fc_model`. Importing this, we can easily create a fully-connected network with `fc_model.Network`, and train the network using `fc_model.train`. I'll use this model (once it's trained) to demonstrate how we can save and load models."
      ]
    },
    {
      "metadata": {
        "id": "l6CgBO4Blnnl",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Create the network, define the criterion and optimizer\n",
        "\n",
        "model = Network(784, 10, [512, 256, 128])\n",
        "criterion = nn.NLLLoss()\n",
        "optimizer = optim.Adam(model.parameters(), lr=0.001)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "B-aJZEsplnnt",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 839
        },
        "outputId": "0ef6c912-9fbf-49f8-c54d-2831a8115eb7"
      },
      "cell_type": "code",
      "source": [
        "train(model, trainloader, testloader, criterion, optimizer, epochs=2)"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 1/2..  Training Loss: 1.751..  Test Loss: 1.033..  Test Accuracy: 0.591\n",
            "Epoch: 1/2..  Training Loss: 1.066..  Test Loss: 0.728..  Test Accuracy: 0.726\n",
            "Epoch: 1/2..  Training Loss: 0.834..  Test Loss: 0.676..  Test Accuracy: 0.745\n",
            "Epoch: 1/2..  Training Loss: 0.807..  Test Loss: 0.646..  Test Accuracy: 0.754\n",
            "Epoch: 1/2..  Training Loss: 0.769..  Test Loss: 0.649..  Test Accuracy: 0.748\n",
            "Epoch: 1/2..  Training Loss: 0.763..  Test Loss: 0.614..  Test Accuracy: 0.764\n",
            "Epoch: 1/2..  Training Loss: 0.699..  Test Loss: 0.573..  Test Accuracy: 0.781\n",
            "Epoch: 1/2..  Training Loss: 0.721..  Test Loss: 0.565..  Test Accuracy: 0.790\n",
            "Epoch: 1/2..  Training Loss: 0.640..  Test Loss: 0.551..  Test Accuracy: 0.795\n",
            "Epoch: 1/2..  Training Loss: 0.650..  Test Loss: 0.541..  Test Accuracy: 0.801\n",
            "Epoch: 1/2..  Training Loss: 0.630..  Test Loss: 0.530..  Test Accuracy: 0.799\n",
            "Epoch: 1/2..  Training Loss: 0.623..  Test Loss: 0.531..  Test Accuracy: 0.807\n",
            "Epoch: 1/2..  Training Loss: 0.586..  Test Loss: 0.560..  Test Accuracy: 0.797\n",
            "Epoch: 1/2..  Training Loss: 0.609..  Test Loss: 0.509..  Test Accuracy: 0.814\n",
            "Epoch: 1/2..  Training Loss: 0.592..  Test Loss: 0.513..  Test Accuracy: 0.814\n",
            "Epoch: 1/2..  Training Loss: 0.571..  Test Loss: 0.497..  Test Accuracy: 0.817\n",
            "Epoch: 1/2..  Training Loss: 0.599..  Test Loss: 0.489..  Test Accuracy: 0.821\n",
            "Epoch: 1/2..  Training Loss: 0.568..  Test Loss: 0.491..  Test Accuracy: 0.819\n",
            "Epoch: 1/2..  Training Loss: 0.566..  Test Loss: 0.495..  Test Accuracy: 0.818\n",
            "Epoch: 1/2..  Training Loss: 0.632..  Test Loss: 0.490..  Test Accuracy: 0.818\n",
            "Epoch: 1/2..  Training Loss: 0.570..  Test Loss: 0.507..  Test Accuracy: 0.816\n",
            "Epoch: 1/2..  Training Loss: 0.561..  Test Loss: 0.475..  Test Accuracy: 0.828\n",
            "Epoch: 1/2..  Training Loss: 0.532..  Test Loss: 0.480..  Test Accuracy: 0.821\n",
            "Epoch: 2/2..  Training Loss: 0.550..  Test Loss: 0.483..  Test Accuracy: 0.815\n",
            "Epoch: 2/2..  Training Loss: 0.540..  Test Loss: 0.493..  Test Accuracy: 0.822\n",
            "Epoch: 2/2..  Training Loss: 0.602..  Test Loss: 0.499..  Test Accuracy: 0.821\n",
            "Epoch: 2/2..  Training Loss: 0.578..  Test Loss: 0.497..  Test Accuracy: 0.817\n",
            "Epoch: 2/2..  Training Loss: 0.539..  Test Loss: 0.468..  Test Accuracy: 0.827\n",
            "Epoch: 2/2..  Training Loss: 0.521..  Test Loss: 0.478..  Test Accuracy: 0.826\n",
            "Epoch: 2/2..  Training Loss: 0.499..  Test Loss: 0.480..  Test Accuracy: 0.824\n",
            "Epoch: 2/2..  Training Loss: 0.537..  Test Loss: 0.458..  Test Accuracy: 0.839\n",
            "Epoch: 2/2..  Training Loss: 0.575..  Test Loss: 0.487..  Test Accuracy: 0.831\n",
            "Epoch: 2/2..  Training Loss: 0.529..  Test Loss: 0.467..  Test Accuracy: 0.826\n",
            "Epoch: 2/2..  Training Loss: 0.539..  Test Loss: 0.456..  Test Accuracy: 0.834\n",
            "Epoch: 2/2..  Training Loss: 0.528..  Test Loss: 0.460..  Test Accuracy: 0.836\n",
            "Epoch: 2/2..  Training Loss: 0.505..  Test Loss: 0.462..  Test Accuracy: 0.832\n",
            "Epoch: 2/2..  Training Loss: 0.532..  Test Loss: 0.449..  Test Accuracy: 0.836\n",
            "Epoch: 2/2..  Training Loss: 0.544..  Test Loss: 0.471..  Test Accuracy: 0.833\n",
            "Epoch: 2/2..  Training Loss: 0.526..  Test Loss: 0.465..  Test Accuracy: 0.829\n",
            "Epoch: 2/2..  Training Loss: 0.531..  Test Loss: 0.449..  Test Accuracy: 0.840\n",
            "Epoch: 2/2..  Training Loss: 0.537..  Test Loss: 0.452..  Test Accuracy: 0.836\n",
            "Epoch: 2/2..  Training Loss: 0.522..  Test Loss: 0.450..  Test Accuracy: 0.836\n",
            "Epoch: 2/2..  Training Loss: 0.515..  Test Loss: 0.449..  Test Accuracy: 0.835\n",
            "Epoch: 2/2..  Training Loss: 0.543..  Test Loss: 0.457..  Test Accuracy: 0.837\n",
            "Epoch: 2/2..  Training Loss: 0.522..  Test Loss: 0.445..  Test Accuracy: 0.838\n",
            "Epoch: 2/2..  Training Loss: 0.463..  Test Loss: 0.445..  Test Accuracy: 0.836\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "uQC2C7J7lnn2",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Saving and loading networks\n",
        "\n",
        "As you can imagine, it's impractical to train a network every time you need to use it. Instead, we can save trained networks then load them later to train more or use them for predictions.\n",
        "\n",
        "The parameters for PyTorch networks are stored in a model's `state_dict`. We can see the state dict contains the weight and bias matrices for each of our layers."
      ]
    },
    {
      "metadata": {
        "id": "yMgeae8Dlnn7",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 285
        },
        "outputId": "cc295690-103f-4ada-dc4e-4a6bbd92d0f4"
      },
      "cell_type": "code",
      "source": [
        "print(\"Our model: \\n\\n\", model, '\\n')\n",
        "print(\"The state dict keys: \\n\\n\", model.state_dict().keys())"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Our model: \n",
            "\n",
            " Network(\n",
            "  (hidden_layers): ModuleList(\n",
            "    (0): Linear(in_features=784, out_features=512, bias=True)\n",
            "    (1): Linear(in_features=512, out_features=256, bias=True)\n",
            "    (2): Linear(in_features=256, out_features=128, bias=True)\n",
            "  )\n",
            "  (output): Linear(in_features=128, out_features=10, bias=True)\n",
            "  (dropout): Dropout(p=0.5)\n",
            ") \n",
            "\n",
            "The state dict keys: \n",
            "\n",
            " odict_keys(['hidden_layers.0.weight', 'hidden_layers.0.bias', 'hidden_layers.1.weight', 'hidden_layers.1.bias', 'hidden_layers.2.weight', 'hidden_layers.2.bias', 'output.weight', 'output.bias'])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "S-ysq2HtlnoH",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "The simplest thing to do is simply save the state dict with `torch.save`. For example, we can save it to a file `'checkpoint.pth'`."
      ]
    },
    {
      "metadata": {
        "id": "6iCTf5YGlnoO",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "torch.save(model.state_dict(), 'checkpoint.pth')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "B_E9LD5Llnod",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "Then we can load the state dict with `torch.load`."
      ]
    },
    {
      "metadata": {
        "id": "R7MHqTajlnok",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "63ec597a-ad90-4010-cf45-f1d8c8c36681"
      },
      "cell_type": "code",
      "source": [
        "state_dict = torch.load('checkpoint.pth')\n",
        "print(state_dict.keys())"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "odict_keys(['hidden_layers.0.weight', 'hidden_layers.0.bias', 'hidden_layers.1.weight', 'hidden_layers.1.bias', 'hidden_layers.2.weight', 'hidden_layers.2.bias', 'output.weight', 'output.bias'])\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "N_caIiPPlnox",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "And to load the state dict in to the network, you do `model.load_state_dict(state_dict)`."
      ]
    },
    {
      "metadata": {
        "id": "dRDHJ6EIlnoz",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "model.load_state_dict(state_dict)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "OZRRlLy3lno9",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "Seems pretty straightforward, but as usual it's a bit more complicated. Loading the state dict works only if the model architecture is exactly the same as the checkpoint architecture. If I create a model with a different architecture, this fails."
      ]
    },
    {
      "metadata": {
        "id": "rl50gh7xlnpA",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 457
        },
        "outputId": "f9f854e8-a6a5-4221-8d03-c3823c05a82c"
      },
      "cell_type": "code",
      "source": [
        "# Try this\n",
        "model = Network(784, 10, [400, 200, 100])\n",
        "# This will throw an error because the tensor sizes are wrong!\n",
        "model.load_state_dict(state_dict)"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "error",
          "ename": "RuntimeError",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-20-9b1d83a68ac6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mNetwork\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m784\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m400\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m200\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m100\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;31m# This will throw an error because the tensor sizes are wrong!\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mload_state_dict\u001b[0;34m(self, state_dict, strict)\u001b[0m\n\u001b[1;32m    717\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0merror_msgs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    718\u001b[0m             raise RuntimeError('Error(s) in loading state_dict for {}:\\n\\t{}'.format(\n\u001b[0;32m--> 719\u001b[0;31m                                self.__class__.__name__, \"\\n\\t\".join(error_msgs)))\n\u001b[0m\u001b[1;32m    720\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    721\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mRuntimeError\u001b[0m: Error(s) in loading state_dict for Network:\n\tsize mismatch for hidden_layers.0.weight: copying a param of torch.Size([400, 784]) from checkpoint, where the shape is torch.Size([512, 784]) in current model.\n\tsize mismatch for hidden_layers.0.bias: copying a param of torch.Size([400]) from checkpoint, where the shape is torch.Size([512]) in current model.\n\tsize mismatch for hidden_layers.1.weight: copying a param of torch.Size([200, 400]) from checkpoint, where the shape is torch.Size([256, 512]) in current model.\n\tsize mismatch for hidden_layers.1.bias: copying a param of torch.Size([200]) from checkpoint, where the shape is torch.Size([256]) in current model.\n\tsize mismatch for hidden_layers.2.weight: copying a param of torch.Size([100, 200]) from checkpoint, where the shape is torch.Size([128, 256]) in current model.\n\tsize mismatch for hidden_layers.2.bias: copying a param of torch.Size([100]) from checkpoint, where the shape is torch.Size([128]) in current model.\n\tsize mismatch for output.weight: copying a param of torch.Size([10, 100]) from checkpoint, where the shape is torch.Size([10, 128]) in current model."
          ]
        }
      ]
    },
    {
      "metadata": {
        "id": "UhoLqcfqlnpP",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "This means we need to rebuild the model exactly as it was when trained. Information about the model architecture needs to be saved in the checkpoint, along with the state dict. To do this, you build a dictionary with all the information you need to completely rebuild the model."
      ]
    },
    {
      "metadata": {
        "id": "Dg3F3nqFlnpT",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Read the layer sizes off the model itself so the stored architecture\n",
        "# always matches the weights in the state dict.\n",
        "checkpoint = {'input_size': model.hidden_layers[0].in_features,\n",
        "              'output_size': model.output.out_features,\n",
        "              'hidden_layers': [each.out_features for each in model.hidden_layers],\n",
        "              'state_dict': model.state_dict()}\n",
        "\n",
        "torch.save(checkpoint, 'checkpoint.pth')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ePAlJ9Pmlnpl",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "Now the checkpoint has all the necessary information to rebuild the trained model. You can easily make that a function if you want. Similarly, we can write a function to load checkpoints. "
      ]
    },
    {
      "metadata": {
        "id": "NmRBJzxJlnpo",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def load_checkpoint(filepath, map_location=None):\n",
        "    \"\"\"Rebuild a Network from a checkpoint file written with torch.save.\n",
        "\n",
        "    Args:\n",
        "        filepath: path to a checkpoint dict with the keys 'input_size',\n",
        "            'output_size', 'hidden_layers' and 'state_dict'.\n",
        "        map_location: forwarded to torch.load; pass 'cpu' to open a\n",
        "            checkpoint that was saved from a GPU on a CPU-only machine.\n",
        "            The default (None) restores tensors to the device they were\n",
        "            saved from.\n",
        "\n",
        "    Returns:\n",
        "        A Network with the saved architecture and weights loaded.\n",
        "    \"\"\"\n",
        "    # NOTE: torch.load unpickles the file, so only open checkpoints you trust.\n",
        "    checkpoint = torch.load(filepath, map_location=map_location)\n",
        "    model = Network(checkpoint['input_size'],\n",
        "                    checkpoint['output_size'],\n",
        "                    checkpoint['hidden_layers'])\n",
        "    model.load_state_dict(checkpoint['state_dict'])\n",
        "\n",
        "    return model"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "y41JNzsNlnp0",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 178
        },
        "outputId": "075c6a7a-3a64-4633-ed6f-875f753eb823"
      },
      "cell_type": "code",
      "source": [
        "# Restore the network from the checkpoint file and print its layer summary.\n",
        "model = load_checkpoint(filepath='checkpoint.pth')\n",
        "print(model)"
      ],
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Network(\n",
            "  (hidden_layers): ModuleList(\n",
            "    (0): Linear(in_features=784, out_features=400, bias=True)\n",
            "    (1): Linear(in_features=400, out_features=200, bias=True)\n",
            "    (2): Linear(in_features=200, out_features=100, bias=True)\n",
            "  )\n",
            "  (output): Linear(in_features=100, out_features=10, bias=True)\n",
            "  (dropout): Dropout(p=0.5)\n",
            ")\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}